{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load MXNet model\n",
    "\n",
    "In this tutorial, you will learn how to load an existing MXNet model and use it to run a prediction task.\n",
    "\n",
    "## Preparation\n",
    "\n",
    "This tutorial requires a Java kernel for Jupyter. For installation instructions, see the [README](https://docs.djl.ai/docs/demos/jupyter/index.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.28.0\n",
    "%maven ai.djl:model-zoo:0.28.0\n",
    "%maven ai.djl.mxnet:mxnet-engine:0.28.0\n",
    "%maven ai.djl.mxnet:mxnet-model-zoo:0.28.0\n",
    "%maven org.slf4j:slf4j-simple:1.7.36"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.awt.image.*;\n",
    "import java.nio.file.*;\n",
    "import ai.djl.*;\n",
    "import ai.djl.inference.*;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.modality.*;\n",
    "import ai.djl.modality.cv.*;\n",
    "import ai.djl.modality.cv.util.*;\n",
    "import ai.djl.modality.cv.transform.*;\n",
    "import ai.djl.modality.cv.translator.*;\n",
    "import ai.djl.translate.*;\n",
    "import ai.djl.training.util.*;\n",
    "import ai.djl.util.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Prepare your MXNet model\n",
    "\n",
    "This tutorial assumes that you have an MXNet model trained using Python. An MXNet symbolic model usually consists of the following files:\n",
    "* Symbol file: {MODEL_NAME}-symbol.json - a JSON file that describes the network architecture of the model\n",
    "* Parameters file: {MODEL_NAME}-{EPOCH}.params - a binary file that stores the trained parameters (weights and biases)\n",
    "* Synset file: synset.txt - an optional text file that stores the classification class labels\n",
    "\n",
    "This tutorial uses a pre-trained MXNet `resnet18_v1` model."
   ]
  },
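  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The naming convention above can be sketched in plain Java. The `MxnetModelFiles` class is a hypothetical helper, not part of DJL. It assumes the epoch is zero-padded to four digits, as in the `resnet18_v1-0000.params` file used in this tutorial, and it only checks that the two required files exist.\n",
    "\n",
    "```java\n",
    "import java.nio.file.Files;\n",
    "import java.nio.file.Path;\n",
    "import java.nio.file.Paths;\n",
    "import java.util.ArrayList;\n",
    "import java.util.List;\n",
    "\n",
    "// Hypothetical helper (not a DJL class): builds the file names that the\n",
    "// naming convention expects and reports which ones are missing.\n",
    "class MxnetModelFiles {\n",
    "\n",
    "    static String symbolFile(String modelName) {\n",
    "        return modelName + \"-symbol.json\";\n",
    "    }\n",
    "\n",
    "    static String paramsFile(String modelName, int epoch) {\n",
    "        // Assumption: the epoch is zero-padded to four digits, e.g. 0000.\n",
    "        return String.format(\"%s-%04d.params\", modelName, epoch);\n",
    "    }\n",
    "\n",
    "    // Returns the required files that are not present in modelDir.\n",
    "    // synset.txt is optional, so it is not checked here.\n",
    "    static List<String> missingFiles(Path modelDir, String modelName, int epoch) {\n",
    "        List<String> missing = new ArrayList<>();\n",
    "        for (String name : List.of(symbolFile(modelName), paramsFile(modelName, epoch))) {\n",
    "            if (!Files.isRegularFile(modelDir.resolve(name))) {\n",
    "                missing.add(name);\n",
    "            }\n",
    "        }\n",
    "        return missing;\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "For example, `MxnetModelFiles.missingFiles(Paths.get(\"build/resnet\"), \"resnet18_v1\", 0)` should return an empty list once the model files for this tutorial have been downloaded."
   ]
  },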
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use `DownloadUtils` to download the model files from the internet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json\", \"build/resnet/resnet18_v1-symbol.json\", new ProgressBar());\n",
    "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz\", \"build/resnet/resnet18_v1-0000.params\", new ProgressBar());\n",
    "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt\", \"build/resnet/synset.txt\", new ProgressBar());\n"
   ]
  },
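  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Load your model\n",
    "\n",
    "`Model.load` takes the model directory and the model name, which is the `{MODEL_NAME}` prefix of the symbol and parameters files."
   ]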
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Path modelDir = Paths.get(\"build/resnet\");\n",
    "Model model = Model.newInstance(\"resnet\");\n",
    "model.load(modelDir, \"resnet18_v1\");"
   ]
  },
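  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Create a `Translator`\n",
    "\n",
    "The `Translator` handles pre-processing and post-processing. Here, the pipeline center-crops the image, resizes it to 224x224, and converts it to a tensor. The model output is mapped to the class labels in `synset.txt`, with softmax applied to produce probabilities."
   ]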
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Pipeline pipeline = new Pipeline();\n",
    "pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());\n",
    "Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()\n",
    "                .setPipeline(pipeline)\n",
    "                .optSynsetArtifactName(\"synset.txt\")\n",
    "                .optApplySoftmax(true)\n",
    "                .build();"
   ]
  },
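  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `optApplySoftmax(true)` option above asks the translator to turn the raw model outputs into probabilities. As a rough illustration of that operation, here is a plain-Java sketch on a `double[]`. It is not DJL's own implementation, and the `SoftmaxSketch` class name is hypothetical.\n",
    "\n",
    "```java\n",
    "// Hypothetical illustration (not DJL code): softmax over raw scores.\n",
    "class SoftmaxSketch {\n",
    "\n",
    "    static double[] softmax(double[] logits) {\n",
    "        // Subtract the maximum before exponentiating for numerical stability.\n",
    "        double max = Double.NEGATIVE_INFINITY;\n",
    "        for (double v : logits) {\n",
    "            max = Math.max(max, v);\n",
    "        }\n",
    "        double sum = 0.0;\n",
    "        double[] out = new double[logits.length];\n",
    "        for (int i = 0; i < logits.length; i++) {\n",
    "            out[i] = Math.exp(logits[i] - max);\n",
    "            sum += out[i];\n",
    "        }\n",
    "        // Normalize so that the values sum to 1.\n",
    "        for (int i = 0; i < out.length; i++) {\n",
    "            out[i] /= sum;\n",
    "        }\n",
    "        return out;\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "The results are non-negative, sum to 1, and keep the ordering of the raw scores, which is why the highest-scoring class also has the highest reported probability."
   ]
  },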
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Load image for classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/kitten.jpg\");\n",
    "img.getWrappedImage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Predictor<Image, Classifications> predictor = model.newPredictor(translator);\n",
    "Classifications classifications = predictor.predict(img);\n",
    "\n",
    "classifications"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "You can now load any MXNet symbolic model and run inference with it.\n",
    "\n",
    "You might also want to check out [load_pytorch_model](https://docs.djl.ai/docs/demos/jupyter/load_pytorch_model.html), which demonstrates loading a local model using the ModelZoo API."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
